Skip to content

Conversation

@atrosinenko
Copy link
Contributor

@atrosinenko atrosinenko commented Apr 14, 2025

  • use more flexible ArrayRef<T> and StringRef types instead of
    const std::vector<T> & and const std::string &, correspondingly,
    for function arguments
  • return plain const SrcState & instead of ErrorOr<const SrcState &>
    from SrcSafetyAnalysis::getStateBefore, as absent state is not
    handled gracefully by any caller

Copy link
Contributor Author

atrosinenko commented Apr 14, 2025

@llvmbot
Copy link
Member

llvmbot commented Apr 14, 2025

@llvm/pr-subscribers-bolt

Author: Anatoly Trosinenko (atrosinenko)

Changes
  • use more flexible const ArrayRef&lt;T&gt; and StringRef types instead of
    const std::vector&lt;T&gt; &amp; and const std::string &amp;, correspondingly,
    for function arguments
  • return plain const SrcState &amp; instead of ErrorOr&lt;const SrcState &amp;&gt;
    from SrcSafetyAnalysis::getStateBefore, as absent state is not
    handled gracefully by any caller

Full diff: https://github.com/llvm/llvm-project/pull/135661.diff

2 Files Affected:

  • (modified) bolt/include/bolt/Passes/PAuthGadgetScanner.h (+2-6)
  • (modified) bolt/lib/Passes/PAuthGadgetScanner.cpp (+17-22)
diff --git a/bolt/include/bolt/Passes/PAuthGadgetScanner.h b/bolt/include/bolt/Passes/PAuthGadgetScanner.h
index 6765e2aff414f..3e39b64e59e0f 100644
--- a/bolt/include/bolt/Passes/PAuthGadgetScanner.h
+++ b/bolt/include/bolt/Passes/PAuthGadgetScanner.h
@@ -12,7 +12,6 @@
 #include "bolt/Core/BinaryContext.h"
 #include "bolt/Core/BinaryFunction.h"
 #include "bolt/Passes/BinaryPasses.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
@@ -199,9 +198,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);
 
 namespace PAuthGadgetScanner {
 
-class SrcSafetyAnalysis;
-struct SrcState;
-
 /// Description of a gadget kind that can be detected. Intended to be
 /// statically allocated to be attached to reports by reference.
 class GadgetKind {
@@ -210,7 +206,7 @@ class GadgetKind {
 public:
   GadgetKind(const char *Description) : Description(Description) {}
 
-  const StringRef getDescription() const { return Description; }
+  StringRef getDescription() const { return Description; }
 };
 
 /// Base report located at some instruction, without any additional information.
@@ -261,7 +257,7 @@ struct GadgetReport : public Report {
 /// Report with a free-form message attached.
 struct GenericReport : public Report {
   std::string Text;
-  GenericReport(MCInstReference Location, const std::string &Text)
+  GenericReport(MCInstReference Location, StringRef Text)
       : Report(Location), Text(Text) {}
   virtual void generateReport(raw_ostream &OS,
                               const BinaryContext &BC) const override;
diff --git a/bolt/lib/Passes/PAuthGadgetScanner.cpp b/bolt/lib/Passes/PAuthGadgetScanner.cpp
index ad47bdff753c8..ed89471cbb8d3 100644
--- a/bolt/lib/Passes/PAuthGadgetScanner.cpp
+++ b/bolt/lib/Passes/PAuthGadgetScanner.cpp
@@ -91,14 +91,14 @@ class TrackedRegisters {
   const std::vector<MCPhysReg> Registers;
   std::vector<uint16_t> RegToIndexMapping;
 
-  static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
+  static size_t getMappingSize(const ArrayRef<MCPhysReg> RegsToTrack) {
     if (RegsToTrack.empty())
       return 0;
     return 1 + *llvm::max_element(RegsToTrack);
   }
 
 public:
-  TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
+  TrackedRegisters(const ArrayRef<MCPhysReg> RegsToTrack)
       : Registers(RegsToTrack),
         RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
     for (unsigned I = 0; I < RegsToTrack.size(); ++I)
@@ -234,7 +234,7 @@ struct SrcState {
 
 static void printLastInsts(
     raw_ostream &OS,
-    const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) {
+    const ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) {
   OS << "Insts: ";
   for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) {
     auto &Set = LastInstWritingReg[I];
@@ -295,7 +295,7 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
 class SrcSafetyAnalysis {
 public:
   SrcSafetyAnalysis(BinaryFunction &BF,
-                    const std::vector<MCPhysReg> &RegsToTrackInstsFor)
+                    const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
       : BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
         RegsToTrackInstsFor(RegsToTrackInstsFor) {}
 
@@ -303,11 +303,10 @@ class SrcSafetyAnalysis {
 
   static std::shared_ptr<SrcSafetyAnalysis>
   create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
-         const std::vector<MCPhysReg> &RegsToTrackInstsFor);
+         const ArrayRef<MCPhysReg> RegsToTrackInstsFor);
 
   virtual void run() = 0;
-  virtual ErrorOr<const SrcState &>
-  getStateBefore(const MCInst &Inst) const = 0;
+  virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0;
 
 protected:
   BinaryContext &BC;
@@ -348,7 +347,7 @@ class SrcSafetyAnalysis {
   }
 
   BitVector getClobberedRegs(const MCInst &Point) const {
-    BitVector Clobbered(NumRegs, false);
+    BitVector Clobbered(NumRegs);
     // Assume a call can clobber all registers, including callee-saved
     // registers. There's a good chance that callee-saved registers will be
     // saved on the stack at some point during execution of the callee.
@@ -409,8 +408,7 @@ class SrcSafetyAnalysis {
 
       // FirstCheckerInst should belong to the same basic block, meaning
       // it was deterministically processed a few steps before this instruction.
-      const SrcState &StateBeforeChecker =
-          getStateBefore(*FirstCheckerInst).get();
+      const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst);
       if (StateBeforeChecker.SafeToDerefRegs[CheckedReg])
         Regs.push_back(CheckedReg);
     }
@@ -523,10 +521,7 @@ class SrcSafetyAnalysis {
                          const ArrayRef<MCPhysReg> UsedDirtyRegs) const {
     if (RegsToTrackInstsFor.empty())
       return {};
-    auto MaybeState = getStateBefore(Inst);
-    if (!MaybeState)
-      llvm_unreachable("Expected state to be present");
-    const SrcState &S = *MaybeState;
+    const SrcState &S = getStateBefore(Inst);
     // Due to aliasing registers, multiple registers may have been tracked.
     std::set<const MCInst *> LastWritingInsts;
     for (MCPhysReg TrackedReg : UsedDirtyRegs) {
@@ -537,7 +532,7 @@ class SrcSafetyAnalysis {
     for (const MCInst *Inst : LastWritingInsts) {
       MCInstReference Ref = MCInstReference::get(Inst, BF);
       assert(Ref && "Expected Inst to be found");
-      Result.push_back(MCInstReference(Ref));
+      Result.push_back(Ref);
     }
     return Result;
   }
@@ -557,11 +552,11 @@ class DataflowSrcSafetyAnalysis
 public:
   DataflowSrcSafetyAnalysis(BinaryFunction &BF,
                             MCPlusBuilder::AllocatorIdTy AllocId,
-                            const std::vector<MCPhysReg> &RegsToTrackInstsFor)
+                            const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
       : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}
 
-  ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
-    return DFParent::getStateBefore(Inst);
+  const SrcState &getStateBefore(const MCInst &Inst) const override {
+    return DFParent::getStateBefore(Inst).get();
   }
 
   void run() override {
@@ -670,7 +665,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
 public:
   CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF,
                               MCPlusBuilder::AllocatorIdTy AllocId,
-                              const std::vector<MCPhysReg> &RegsToTrackInstsFor)
+                              const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
       : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
     StateAnnotationIndex =
         BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis");
@@ -704,7 +699,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
     }
   }
 
-  ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
+  const SrcState &getStateBefore(const MCInst &Inst) const override {
     return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex);
   }
 
@@ -714,7 +709,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
 std::shared_ptr<SrcSafetyAnalysis>
 SrcSafetyAnalysis::create(BinaryFunction &BF,
                           MCPlusBuilder::AllocatorIdTy AllocId,
-                          const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
+                          const ArrayRef<MCPhysReg> RegsToTrackInstsFor) {
   if (BF.hasCFG())
     return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId,
                                                        RegsToTrackInstsFor);
@@ -821,7 +816,7 @@ Analysis::findGadgets(BinaryFunction &BF,
 
   BinaryContext &BC = BF.getBinaryContext();
   iterateOverInstrs(BF, [&](MCInstReference Inst) {
-    const SrcState &S = *Analysis->getStateBefore(Inst);
+    const SrcState &S = Analysis->getStateBefore(Inst);
 
     // If non-empty state was never propagated from the entry basic block
     // to Inst, assume it to be unreachable and report a warning.

Copy link
Collaborator

@kbeyls kbeyls left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-use-better-types branch from 51373db to d57dc48 Compare April 18, 2025 16:34
@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-signing-oracles branch from 56ea6bc to 57ca35a Compare April 22, 2025 16:08
@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-use-better-types branch from d57dc48 to cf7d310 Compare April 22, 2025 16:08
@atrosinenko
Copy link
Contributor Author

Sorry for updating an already approved PR, just spotted that const ArrayRef<T> is redundant: ArrayRef<T> is enough to replace const T* with size information - it is MutableArrayRef to be used in place of non-const T*. Updated #135662 and #135663 accordingly.

@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-use-better-types branch 2 times, most recently from 554bcfd to a5b966d Compare April 30, 2025 14:54
@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-signing-oracles branch from 10b2ace to 5cc58f2 Compare April 30, 2025 14:54
@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-signing-oracles branch from 5cc58f2 to 66db728 Compare May 16, 2025 17:10
@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-use-better-types branch 2 times, most recently from 5e2fd61 to eb984d9 Compare May 20, 2025 10:03
Base automatically changed from users/atrosinenko/bolt-gs-signing-oracles to main May 20, 2025 10:42
* use more flexible `const ArrayRef<T>` and `StringRef` types instead of
  `const std::vector<T> &` and `const std::string &`, correspondingly,
  for function arguments
* return plain `const SrcState &` instead of `ErrorOr<const SrcState &>`
  from `SrcSafetyAnalysis::getStateBefore`, as absent state is not
  handled gracefully by any caller
@atrosinenko atrosinenko force-pushed the users/atrosinenko/bolt-gs-use-better-types branch from eb984d9 to 62b27f9 Compare May 20, 2025 10:44
@atrosinenko atrosinenko merged commit 14706d6 into main May 20, 2025
10 checks passed
@atrosinenko atrosinenko deleted the users/atrosinenko/bolt-gs-use-better-types branch May 20, 2025 11:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants